{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据读取"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Program Files\\Anaconda3\\lib\\site-packages\\pandas\\compat\\_optional.py:106: UserWarning: Pandas requires version '1.2.1' or newer of 'bottleneck' (version '1.1.0' currently installed).\n",
      "  warnings.warn(msg, UserWarning)\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import sklearn\n",
    "from gensim.models import Word2Vec\n",
    "import matplotlib.pyplot as plt\n",
    "import jieba\n",
    "import matplotlib as mpl\n",
    "mpl.rcParams['font.sans-serif'] = ['SimHei']\n",
    "mpl.rcParams['font.serif'] = ['SimHei']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>类别名字</th>\n",
       "      <th>文本数据</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>娱乐</td>\n",
       "      <td>黎明喜当爹满面春风，乐基儿发福成大妈！离婚后两人差距这么大？</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>旅游</td>\n",
       "      <td>亚洲排名第一都市，北京GDP不及它一半，上海10年都赶超不了</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>国际</td>\n",
       "      <td>悠仁：日本明仁天皇孙辈唯一男丁堪称“独苗”，将来可能继位</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>娱乐</td>\n",
       "      <td>被意大利炮干掉的阿部规秀到底是个什么官？</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>财经</td>\n",
       "      <td>公司亏损时，股东如何维护自己的利益？</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  类别名字                            文本数据\n",
       "0   娱乐  黎明喜当爹满面春风，乐基儿发福成大妈！离婚后两人差距这么大？\n",
       "1   旅游  亚洲排名第一都市，北京GDP不及它一半，上海10年都赶超不了\n",
       "2   国际    悠仁：日本明仁天皇孙辈唯一男丁堪称“独苗”，将来可能继位\n",
       "3   娱乐            被意大利炮干掉的阿部规秀到底是个什么官？\n",
       "4   财经              公司亏损时，股东如何维护自己的利益？"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_path = \"../data/data.csv\"\n",
    "stopword_path = \"../data/stop_words.txt\"\n",
    "data = pd.read_csv(data_path)\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 去除证券类\n",
    "data = data[data[\"类别名字\"] != \"证券\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 分词去停用词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(stopword_path, \"r\", encoding=\"utf-8\") as f:\n",
    "    stopwords = f.readlines()\n",
    "\n",
    "stopwords = [i.strip() for i in stopwords]\n",
    "\n",
    "def segment(data):\n",
    "    data = jieba.cut(data)\n",
    "    data = [i for i in data if i not in stopwords]\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Dumping model to file cache C:\\Users\\jtsun\\AppData\\Local\\Temp\\jieba.cache\n",
      "Loading model cost 0.803 seconds.\n",
      "Prefix dict has been built succesfully.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>类别名字</th>\n",
       "      <th>文本数据</th>\n",
       "      <th>分词后的数据</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>娱乐</td>\n",
       "      <td>黎明喜当爹满面春风，乐基儿发福成大妈！离婚后两人差距这么大？</td>\n",
       "      <td>[黎明, 喜当爹, 满面春风, 乐基儿, 发福, 成, 大妈, 离婚, 两人, 差距, 大]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>旅游</td>\n",
       "      <td>亚洲排名第一都市，北京GDP不及它一半，上海10年都赶超不了</td>\n",
       "      <td>[亚洲, 排名, 第一, 都市, 北京, GDP, 不及, 它, 一半, 上海, 10, 年...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>国际</td>\n",
       "      <td>悠仁：日本明仁天皇孙辈唯一男丁堪称“独苗”，将来可能继位</td>\n",
       "      <td>[悠仁, 日本, 明, 仁天皇, 孙辈, 唯一, 男丁, 堪称, 独苗, 将来, 可能, 继位]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>娱乐</td>\n",
       "      <td>被意大利炮干掉的阿部规秀到底是个什么官？</td>\n",
       "      <td>[被, 意大利, 炮, 干掉, 阿部规秀, 到底, 是, 个, 什么, 官]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>财经</td>\n",
       "      <td>公司亏损时，股东如何维护自己的利益？</td>\n",
       "      <td>[公司, 亏损, 时, 股东, 如何, 维护, 自己, 利益]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  类别名字                            文本数据  \\\n",
       "0   娱乐  黎明喜当爹满面春风，乐基儿发福成大妈！离婚后两人差距这么大？   \n",
       "1   旅游  亚洲排名第一都市，北京GDP不及它一半，上海10年都赶超不了   \n",
       "2   国际    悠仁：日本明仁天皇孙辈唯一男丁堪称“独苗”，将来可能继位   \n",
       "3   娱乐            被意大利炮干掉的阿部规秀到底是个什么官？   \n",
       "4   财经              公司亏损时，股东如何维护自己的利益？   \n",
       "\n",
       "                                              分词后的数据  \n",
       "0     [黎明, 喜当爹, 满面春风, 乐基儿, 发福, 成, 大妈, 离婚, 两人, 差距, 大]  \n",
       "1  [亚洲, 排名, 第一, 都市, 北京, GDP, 不及, 它, 一半, 上海, 10, 年...  \n",
       "2   [悠仁, 日本, 明, 仁天皇, 孙辈, 唯一, 男丁, 堪称, 独苗, 将来, 可能, 继位]  \n",
       "3             [被, 意大利, 炮, 干掉, 阿部规秀, 到底, 是, 个, 什么, 官]  \n",
       "4                    [公司, 亏损, 时, 股东, 如何, 维护, 自己, 利益]  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data[\"分词后的数据\"] = data[\"文本数据\"].apply(segment)\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(294304,)\n",
      "['黎明', '喜当爹', '满面春风', '乐基儿', '发福', '成', '大妈', '离婚', '两人', '差距', '大']\n"
     ]
    }
   ],
   "source": [
    "data_all = data[\"分词后的数据\"].tolist()\n",
    "print(np.shape(data_all))\n",
    "print(data_all[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_word2vec(sentences,fname,min_count_num,embedding_dim):\n",
    "    if not os.path.exists(fname):\n",
    "        model = Word2Vec(sentences, min_count=min_count_num, size=embedding_dim)\n",
    "        model.save(fname)\n",
    "    else:\n",
    "        model = Word2Vec.load(fname)\n",
    "    \n",
    "    unk_vector = np.zeros(shape=[embedding_dim,])\n",
    "    \n",
    "    data_tensor = []\n",
    "    for sent in sentences:        \n",
    "        sent_matrix = []\n",
    "        for word in sent:\n",
    "            try:\n",
    "                temp_vector = model[word]\n",
    "            except KeyError:\n",
    "                temp_vector = unk_vector\n",
    "            sent_matrix.append(temp_vector)\n",
    "        sent_matrix = np.mean(sent_matrix,axis=0)\n",
    "        data_tensor.append(sent_matrix)\n",
    "        \n",
    "    return np.array(data_tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Program Files\\Anaconda3\\lib\\site-packages\\gensim\\models\\base_any2vec.py:743: UserWarning: C extension not loaded, training will be slow. Install a C compiler and reinstall gensim for fast training.\n",
      "  \"C extension not loaded, training will be slow. \"\n",
      "D:\\Program Files\\Anaconda3\\lib\\site-packages\\ipykernel\\__main__.py:15: DeprecationWarning: Call to deprecated `__getitem__` (Method will be removed in 4.0.0, use self.wv.__getitem__() instead).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(294304, 50)\n"
     ]
    }
   ],
   "source": [
    "embedding_dim = 50\n",
    "min_count_num=5\n",
    "fname = '../data/word2vec.model'\n",
    "data_all = get_word2vec(data_all,fname,min_count_num,embedding_dim=embedding_dim)\n",
    "print(np.shape(data_all))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "静待训练词向量……"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_all = data[\"类别名字\"].tolist()\n",
    "\n",
    "def change_label(data):\n",
    "    dict_all = {}\n",
    "    set_all = list(set(data))\n",
    "    for i,content in enumerate(set_all):\n",
    "        dict_all[content] = i\n",
    "    return dict_all\n",
    "dict_all = change_label(label_all)\n",
    "label_id = [dict_all[i] for i in label_all]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(294304, 50)\n",
      "(294304,)\n",
      "{'体育': 0, '教育': 12, '军事': 2, '房产': 3, '旅游': 6, '民生': 7, '财经': 8, '科技': 9, '电竞': 10, '汽车': 5, '文化': 11, '国际': 1, '娱乐': 13, '农业': 4}\n"
     ]
    }
   ],
   "source": [
    "print(np.shape(data_all))\n",
    "print(np.shape(label_id))\n",
    "print(dict_all)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练集和测试集的划分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(235443, 50)\n",
      "(58861, 50)\n",
      "(235443,)\n",
      "(58861,)\n"
     ]
    }
   ],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "train_data,test_data,train_label,test_label = train_test_split(data_all,label_id,train_size=0.8,test_size=0.2)\n",
    "print(np.shape(train_data))\n",
    "print(np.shape(test_data))\n",
    "print(np.shape(train_label))\n",
    "print(np.shape(test_label))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型部分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Program Files\\Anaconda3\\lib\\site-packages\\sklearn\\linear_model\\logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
      "  FutureWarning)\n",
      "D:\\Program Files\\Anaconda3\\lib\\site-packages\\sklearn\\linear_model\\logistic.py:459: FutureWarning: Default multi_class will be changed to 'auto' in 0.22. Specify the multi_class option to silence this warning.\n",
      "  \"this warning.\", FutureWarning)\n"
     ]
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score\n",
    "\n",
    "model = LogisticRegression()\n",
    "model.fit(train_data,train_label)\n",
    "pred_result = model.predict(test_data)\n",
    "\n",
    "train_pred_result = model.predict(train_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练集结果：\n",
      "准确率： 0.7021189842127394\n",
      "紧缺率： 0.6902410995433323\n",
      "召回率： 0.6709492466380821\n",
      "f1值： 0.6787378012150175\n",
      "测试集结果：\n",
      "准确率： 0.6988668218344914\n",
      "紧缺率： 0.686264148358806\n",
      "召回率： 0.6680965871272807\n",
      "f1值： 0.6752743529695983\n"
     ]
    }
   ],
   "source": [
    "print(\"训练集结果：\")\n",
    "accuracy = accuracy_score(y_pred=train_pred_result,y_true=train_label)\n",
    "precision = np.average(precision_score(y_pred=train_pred_result,y_true=train_label,average=None))\n",
    "recall = np.average(recall_score(y_pred=train_pred_result,y_true=train_label,average=None))\n",
    "f1score = np.average(f1_score(y_pred=train_pred_result,y_true=train_label,average=None))\n",
    "print(\"准确率：\",accuracy)\n",
    "print(\"紧缺率：\",precision)\n",
    "print(\"召回率：\",recall)\n",
    "print(\"f1值：\",f1score)\n",
    "\n",
    "print(\"测试集结果：\")\n",
    "accuracy = accuracy_score(y_pred=pred_result,y_true=test_label)\n",
    "precision = np.average(precision_score(y_pred=pred_result,y_true=test_label,average=None))\n",
    "recall = np.average(recall_score(y_pred=pred_result,y_true=test_label,average=None))\n",
    "f1score = np.average(f1_score(y_pred=pred_result,y_true=test_label,average=None))\n",
    "print(\"准确率：\",accuracy)\n",
    "print(\"紧缺率：\",precision)\n",
    "print(\"召回率：\",recall)\n",
    "print(\"f1值：\",f1score)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "请输入待分类的文本：海边看日落\n",
      "当前文本类别为: 旅游\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Program Files\\Anaconda3\\lib\\site-packages\\ipykernel\\__main__.py:23: DeprecationWarning: Call to deprecated `__getitem__` (Method will be removed in 4.0.0, use self.wv.__getitem__() instead).\n"
     ]
    }
   ],
   "source": [
    "input_data = input(\"请输入待分类的文本：\")\n",
    "\n",
    "with open(stopword_path,\"r\",encoding=\"utf-8\") as f:\n",
    "    stopwords = f.readlines()\n",
    "\n",
    "stopwords = [i.strip() for i in stopwords]\n",
    "\n",
    "def segment(data):\n",
    "    data = list(jieba.cut(data))\n",
    "    data = [i for i in data if i not in stopwords]\n",
    "    return data\n",
    "\n",
    "\n",
    "def get_sent_word2vec(sent,seq_length=30):\n",
    "    \n",
    "    seg_data = segment(sent)\n",
    "    word2vec_model = Word2Vec.load(\"../data/word2vec.model\")\n",
    "    \n",
    "    unk_vector = np.zeros(shape=[embedding_dim,])\n",
    "    sent_matrix = []\n",
    "    for word in sent:\n",
    "        try:\n",
    "            temp_vector = word2vec_model[word]\n",
    "        except KeyError:\n",
    "            temp_vector = unk_vector\n",
    "        sent_matrix.append(temp_vector)\n",
    "    tensor = np.reshape(np.mean(sent_matrix,axis=0),(1,embedding_dim))\n",
    "    return tensor\n",
    "    \n",
    "\n",
    "\n",
    "data_vector = get_sent_word2vec(input_data)\n",
    "\n",
    "pred_result = model.predict(data_vector)\n",
    "dict_all_reverse = dict(zip(dict_all.values(),dict_all.keys()))\n",
    "print(\"当前文本类别为:\",dict_all_reverse[pred_result[0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
